{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from efficientnet_pytorch import EfficientNet\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "import numpy as np\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# you can find a pretrained model at model/b3.pth\n",
    "MODEL_F = '/home/sharif/Documents/commai-challenge/model/b0.pth'\n",
    "# directory with the numpy optical flow images you want to use for inference\n",
    "OF_NPY_DIR = '/home/sharif/Documents/RAFT/test_predictions'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check if cuda is available\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "V = 0     # what version of efficientnet did you use\n",
    "IN_C = 2  # number of input channels\n",
    "NUM_C = 1 # number of classes to predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained weights for efficientnet-b0\n"
     ]
    }
   ],
   "source": [
    "model = EfficientNet.from_pretrained(f'efficientnet-b{V}', in_channels=IN_C, num_classes=NUM_C)\n",
    "state = torch.load(MODEL_F)\n",
    "model.load_state_dict(state)\n",
    "model.to(device);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(of_f):\n",
    "    of = np.load(of_f)\n",
    "    i = torch.from_numpy(of).to(device)\n",
    "    pred = model(i)\n",
    "    del i\n",
    "    torch.cuda.empty_cache()\n",
    "    return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loop over all files in directory and predict\n",
    "for f in Path(OF_NPY_DIR).glob('*.npy'):\n",
    "    y_hat = inference(f).item()\n",
    "    print(f'{f.name}: {round(y_hat, 2)}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
